{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用torchvision的數據集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mlearn as mlearn\n",
    "from mlearn import functional as F\n",
    "from mlearn import layers\n",
    "from mlearn.optimizers import SGD, RMSProp, Momentum, Adam\n",
    "from torchvision import datasets\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mlearn.utils import DataLoader\n",
    "from mlearn.utils import pre_F as P\n",
    "import sys\n",
    "from time import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加載數據集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST training and test splits, stored under the local \"datasets\" directory.\n",
    "train = datasets.MNIST(\"datasets\", train=True, download=True)\n",
    "test = datasets.MNIST(\"datasets\", train=False, download=True)\n",
    "\n",
    "# Preprocessing steps handed to every DataLoader.\n",
    "# Alternatives that were tried: [P.normalize], or [] for no preprocessing.\n",
    "pre = [P.normalize_MinMax]\n",
    "\n",
    "\n",
    "def _make_loader(split):\n",
    "    \"\"\"Wrap one MNIST split in a shuffled DataLoader that uses batch_size=32.\"\"\"\n",
    "    return DataLoader((split.data, split.targets), batch_size=32, shuffle=True,\n",
    "                      preprocessing=pre)\n",
    "\n",
    "\n",
    "trainset = _make_loader(train)\n",
    "testset = _make_loader(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Simple_Net(mlearn.Module):\n",
    "    \"\"\"Small fully connected network built from Dense(784, 300) and Dense(300, 10).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Presumably Dense(in_features, out_features) - TODO confirm against mlearn.\n",
    "        self.dense1 = layers.Dense(784, 300)\n",
    "        self.dense2 = layers.Dense(300, 10)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        \"\"\"Return relu(dense2(tanh(dense1(inputs)))).\"\"\"\n",
    "        hidden = F.tanh(self.dense1(inputs))\n",
    "        return F.relu(self.dense2(hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "网络参数加载完毕\n"
     ]
    }
   ],
   "source": [
    "class Net(mlearn.Module):\n",
    "    \"\"\"Five Dense layers (784-500-256-128-64-10) with ReLU / leaky-ReLU activations.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Layers are created in the same order as they are applied in forward().\n",
    "        self.dense1 = layers.Dense(784, 500)\n",
    "        self.dense2 = layers.Dense(500, 256)\n",
    "        self.dense3 = layers.Dense(256, 128)\n",
    "        self.dense4 = layers.Dense(128, 64)\n",
    "        self.dense5 = layers.Dense(64, 10)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        \"\"\"Run inputs through the five layers and return the final activations.\n",
    "\n",
    "        The first three layers are followed by ReLU, the last two by\n",
    "        leaky ReLU with a second argument of 0.01.\n",
    "        \"\"\"\n",
    "        hidden = F.relu(self.dense1(inputs))\n",
    "        hidden = F.relu(self.dense2(hidden))\n",
    "        hidden = F.relu(self.dense3(hidden))\n",
    "        hidden = F.leaky_relu(self.dense4(hidden), 0.01)\n",
    "        return F.leaky_relu(self.dense5(hidden), 0.01)\n",
    "# Build the network and load its parameters from a file on disk.\n",
    "# NOTE(review): the .pkl extension suggests pickle - only load files you trust.\n",
    "net = Net()\n",
    "net.load_wb('saved_param/Param4Test.pkl')\n",
    "# Per-epoch average training losses; fit() appends to it and a later cell plots it.\n",
    "hist = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/10  [AVG Loss -> 0.44371]\n",
      "2/10  [AVG Loss -> 0.09050]\n",
      "3/10  [AVG Loss -> 0.06110]\n",
      "4/10  [AVG Loss -> 0.04394]\n",
      "5/10  [AVG Loss -> 0.03529]\n",
      "6/10  [AVG Loss -> 0.03109]\n",
      "7/10  [AVG Loss -> 0.02512]\n",
      "8/10  [AVG Loss -> 0.02570]\n",
      "9/10  [AVG Loss -> 0.02029]\n",
      "10/10  [AVG Loss -> 0.01907]\n",
      "Finished in 156.5337255001068\n",
      "trainning completed!\n"
     ]
    }
   ],
   "source": [
    "# Training notes:\n",
    "# - If the features are not normalized and every activation function comes from\n",
    "#   the ReLU family, the weights blow up (weights -> NaN).\n",
    "# - Solution: normalize the inputs, or add a nonlinear activation such as tanh.\n",
    "# - Tanh makes the network fit more slowly, and it is slower than ReLU.\n",
    "# - However, per the original note, tanh can avoid vanishing gradients for some\n",
    "#   very small networks.\n",
    "\n",
    "def fit(hist, epochs=10, lr=0.001):\n",
    "    \"\"\"Train the module-level `net` on `trainset` with Adam and cross-entropy loss.\n",
    "\n",
    "    Args:\n",
    "        hist: list that receives one average training loss per epoch\n",
    "            (mutated in place).\n",
    "        epochs: number of passes over `trainset`.\n",
    "        lr: second argument passed to `Adam` (presumably the learning rate).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `trainset` yields no batches during an epoch.\n",
    "    \"\"\"\n",
    "    optimizer = Adam(net, lr)\n",
    "    start = time()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        # Count batches explicitly instead of relying on a leaked loop index.\n",
    "        n_batches = 0\n",
    "        for features, labels in trainset:\n",
    "            net.zero_grad()\n",
    "            # Flatten each sample to 784 values before the first Dense layer.\n",
    "            predict = net(features.reshape(-1, 784))\n",
    "            loss = F.cross_entropy(predict, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.data\n",
    "            n_batches += 1\n",
    "        if n_batches == 0:\n",
    "            raise ValueError(\"trainset yielded no batches; cannot compute an average loss\")\n",
    "        avg_loss = running_loss / n_batches\n",
    "        print(\"%d/%d  [AVG Loss -> %.5f]\" % (epoch + 1, epochs, avg_loss))\n",
    "        hist.append(avg_loss)\n",
    "    print(f\"Finished in {time() - start}\")\n",
    "    print('trainning completed!')\n",
    "fit(hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 0.98003\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAaIklEQVR4nO3da2xcZ37f8e9/Zji8c4YSqRuHEuVrLJvkekPLjtMGi6yB2MjGLtDd1l6nSIsARoG4u2kCtO5tX7ivsimy7QsjiJGkKNr1OrvOtnA2TtxcdtEU6GpF26v7ytbKsniTRIn363Bm/n0xQ2okD8WRRPLMnPl9AEHn8nDmj4H0m4fPc85zzN0REZHqFwm6ABER2RwKdBGRkFCgi4iEhAJdRCQkFOgiIiERC+qNOzo6vKenJ6i3FxGpSu+9995Vd+8sdS6wQO/p6WFwcDCotxcRqUpm9sl65zTkIiISEgp0EZGQUKCLiISEAl1EJCQU6CIiIaFAFxEJCQW6iEhIVF2gD16Y4Hf+8ido2V8RkRtVXaCfHJnm93/wUy7NLAVdiohIRam6QO/rTgJwbGg64EpERCpL1QX6ob1txCLGiZGpoEsREakoVRfoDXVRHtjdyvFh9dBFRIpVXaAD9HcnOD48rYlREZEiVRnofakk04srXJxYCLoUEZGKUZWB3tuVAOCYhl1ERNZUZaA/uKeV+liEE8OaGBURWVWVgV4XjXBoX5t66CIiRaoy0AH6uhKcHJkmm9PEqIgIVHOgp5IspLOcH58LuhQRkYpQtYHe362JURGRYlUb6Ac7WmiORzmuiVEREaCKAz0aMR7pSuiOURGRgqoNdID+7iSnx2ZIZ3JBlyIiEriqDvTergTpTI4PL88GXYqISOCqOtD7U/mldDXsIiJSZqCb2dNmdtbMzpnZK7do90UzczMb2LwS19e9o5FkU50mRkVEKCPQzSwKvAY8AxwCXjCzQyXatQJfAY5sdpG3qI3eroQuXRQRobwe+mHgnLufd/c08CbwXIl2/xH4OrCtz4brTyX58PIsSyvZ7XxbEZGKU06gdwFDRfvDhWNrzOxRoNvdv7eJtZWlN5Ugm3NOjc5s91uLiFSUcgLdShxbW0DFzCLAN4Df3vCFzF4ys0EzGxwfHy+/ylu4PjGqcXQRqW3lBPow0F20nwJGi/ZbgUeAH5jZBeAJ4O1SE6Pu/rq7D7j7QGdn551XXWRPooFdrfWc0Di6iNS4cgL9KHC/mR00szjwPPD26kl3n3b3Dnfvcfce4IfAs+4+uCUVl9CXSnBMPXQRqXEbBrq7Z4CXgXeBM8C33f2Umb1qZs9udYHl6EslOX91ntmllaBLEREJTKycRu7+DvDOTce+tk7bz919WbenL5XAHU6OzPBz9+7c7rcXEakIVX2n6Ko+TYyKiIQj0Hc0x0m1N2oJABGpaaEIdMhfvnh8RD10EaldoQn03lSCoYlFJubTQZciIhKI0AR6Xyr/SDqNo4tIrQpNoPd25QNdNxiJSK0KTaC3NtRxT2ezVl4UkZoVmkCHwsSohlxEpEaFKtD7UgmuzC5zeWZbV/AVEakIoQt0gGND6qWLSO0JVaAf2psgGjHdYCQiNSlUgd4Yj/LA7laOjyjQRaT2hCrQAfq6EhwfnsLdN24sIhIi4Qv07gRTCysMTSwGXYqIyLYKXaCvPZJO67qISI0JXaA/sLuVeDSiiVERqTmhC/R4LMJD+9p06aKI1JzQBTrkJ0ZPjkyTy2liVERqRzgDPZVgPp3l/NW5oEsREdk2oQz0/u7VR9JpHF1EakcoA/3ezhaa4lEFuojUlFAGejRiPLIvwTGtvCgiNSSUgQ75cfTTozOsZHNBlyIisi1CG+i9qQTLmRwfXp4NuhQRkW0R2kBfu2NU4+giUiNCG+gHdjbR1hBToItIzQhtoJsZfXoknYjUkNAGOuQnRs9emmVpJRt0KSIi
Wy7kgZ4kk3POjM0EXYqIyJYLeaDnnzGqcXQRqQWhDvS9iQY6Wup1g5GI1IRQB7qZ0Z9KcEI9dBGpAaEOdMjfYHRufI655UzQpYiIbKnQB3p/Kok7nBxRL11Ewi30gd5bmBjVsIuIhF3oA72jpZ6uZKMmRkUk9EIf6JC/fFGXLopI2NVIoCe5OLHA1EI66FJERLZMWYFuZk+b2VkzO2dmr5Q4/8/N7ISZ/djM/q+ZHdr8Uu+cbjASkVqwYaCbWRR4DXgGOAS8UCKw33D3Xnf/DPB14Pc2vdK78EjXaqBrHF1EwqucHvph4Jy7n3f3NPAm8FxxA3cvXiylGfDNK/HuJRrruKejWT10EQm1WBltuoChov1h4PGbG5nZbwC/BcSBXyz1Qmb2EvASwP79+2+31rvSm0pw5PzEtr6niMh2KqeHbiWOfaoH7u6vufu9wL8G/n2pF3L31919wN0HOjs7b6/Su9SXSnJpZokrM0vb+r4iItulnEAfBrqL9lPA6C3avwn8g7spaitoYlREwq6cQD8K3G9mB80sDjwPvF3cwMzuL9r9ZeCjzStxczy8r42IaWJURMJrwzF0d8+Y2cvAu0AU+GN3P2VmrwKD7v428LKZPQWsAJPAr21l0XeiKR7jgd2tHFMPXURCqpxJUdz9HeCdm459rWj7q5tc15bo7UrwNz+5grtjVmpqQESketXEnaKr+rqTTMynGZ5cDLoUEZFNV1OB3q+JUREJsZoK9Af3tFIXNY6PaGJURMKnpgK9Phblob1tHB9SD11EwqemAh3y16OfHJkml6uo1QlERO5a7QV6V5LZ5QwfX5sPuhQRkU1Ve4HerZUXRSScai7Q7+tsobEuyjGNo4tIyNRcoMeiER7e18aJEQW6iIRLzQU65FdePDU6TSabC7oUEZFNU5OB3t+dYGklx0dX5oIuRURk09RkoPfqkXQiEkI1Geg9O5tpbYhp5UURCZWaDPRIxOhLJTihQBeREKnJQAfo7Uryk0szLGeyQZciIrIpajbQ+1MJVrLOmbHZoEsREdkUNRvofd1JAE5oYlREQqJmA31fooGdzXFNjIpIaNRsoJvlJ0Z16aKIhEXNBjpAbyrJuStzzC9ngi5FROSu1XSg96cS5BxOjc4EXYqIyF2r6UDvTemOUREJj5oO9F2tDexNNOih0SISCjUd6IAmRkUkNBToqSQXri0wvbASdCkiIndFgV4YR9cDL0Sk2inQu/J3jB7TsIuIVLmaD/REUx09O5s0ji4iVa/mAx3yNxhpKV0RqXYKdPI3GI1OLzE+uxx0KSIid0yBTv5KF9ANRiJS3RTowMP72ogYusFIRKqaAh1oro9x364W9dBFpKop0Av6UkmOD0/j7kGXIiJyRxToBX2pBNfm04xOLwVdiojIHVGgF6xNjA5p2EVEqpMCveChva3URU2PpBORqqVAL6iPRXlwTysnRtRDF5HqVFagm9nTZnbWzM6Z2Sslzv+WmZ02s+Nm9jdmdmDzS916qxOjuZwmRkWk+mwY6GYWBV4DngEOAS+Y2aGbmn0ADLh7H/AW8PXNLnQ79KcSzC5luHBtPuhSRERuWzk99MPAOXc/7+5p4E3gueIG7v59d18o7P4QSG1umdujt7DyopbSFZFqVE6gdwFDRfvDhWPr+XXgL0qdMLOXzGzQzAbHx8fLr3KbPLC7hYa6CMeGFOgiUn3KCXQrcazkILOZ/SowAPxuqfPu/rq7D7j7QGdnZ/lVbpNYNMLD+/RIOhGpTuUE+jDQXbSfAkZvbmRmTwH/DnjW3at22cLergSnRmfIZHNBlyIiclvKCfSjwP1mdtDM4sDzwNvFDczsUeAPyIf5lc0vc/v0dydYXMlybnwu6FJERG7LhoHu7hngZeBd4AzwbXc/ZWavmtmzhWa/C7QA3zGzH5vZ2+u8XMVbnRg9rnF0EakysXIaufs7wDs3Hfta0fZTm1xXYO7paKa1PsbxkSn+0WPdG/+AiEiF0J2iN4lE
jEe6ElobXUSqjgK9hL5UgjNjMyxnskGXIiJSNgV6CX2pJCtZ5+yl2aBLEREpmwK9hL5UAkArL4pIVVGgl5Bqb6S9qY4TusFIRKqIAr0EM1tbeVFEpFoo0NfRn0rw4eVZFtKZoEsRESmLAn0dvakkOYfTozNBlyIiUhYF+jr6NTEqIlVGgb6OXW0N7Glr0MqLIlI1FOi30JtKcEI9dBGpEgr0W+hPJTh/dZ7pxZWgSxER2ZAC/Rb6UvmVF0/qkXQiUgUU6LfQ25WfGNX16CJSDRTot9DeHGf/jiZNjIpIVVCgb6AvpaV0RaQ6KNA30JdKMDK1yNW5qn1MqojUCAX6BlYnRnX5oohUOgX6Bh7pSmAGxzSOLiIVToG+gZb6GPd2tqiHLiIVT4Fehr5UgmPD07h70KWIiKxLgV6Gvq4EV+eWGZteCroUEZF1KdDL0NednxjV5YsiUskU6GU4tLeNWMR0g5GIVDQFehka6qI8sLtVPXQRqWgK9DL1dyc4PjyliVERqVgK9DL1pZLMLGX45NpC0KWIiJSkQC/T6sqLusFIRCqVAr1MD+5ppT4W0Q1GIlKxFOhlqotGOLSvTROjIlKxFOi3oa8rwcnRabI5TYyKSOVRoN+GvlSShXSWn47PBV2KiMinKNBvQ393YWJ0SBOjIlJ5FOi34WBHC83xqMbRRaQiKdBvQzRiPNKV4PiIAl1EKo8C/Tb1dyc5MzpDOpMLuhQRkRso0G9Tb1eCdDbH2UuzQZciInKDsgLdzJ42s7Nmds7MXilx/hfM7H0zy5jZFze/zMrRX3jG6PERTYyKSGXZMNDNLAq8BjwDHAJeMLNDNzW7CPxT4I3NLrDSdO9oJNlUx/EhjaOLSGWJldHmMHDO3c8DmNmbwHPA6dUG7n6hcC70A8tmRq8mRkWkApUz5NIFDBXtDxeO3TYze8nMBs1scHx8/E5eoiL0p5J8eHmWxXQ26FJERNaUE+hW4tgd3fvu7q+7+4C7D3R2dt7JS1SEvlSCbM45PaZeuohUjnICfRjoLtpPAaNbU0516EvpGaMiUnnKCfSjwP1mdtDM4sDzwNtbW1Zl25NoYFdrvQJdRCrKhoHu7hngZeBd4AzwbXc/ZWavmtmzAGb2mJkNA18C/sDMTm1l0ZXgM91J/uLkGP/hf53kzNhM0OWIiGBBPSNzYGDABwcHA3nvzTA0scA3/upDvndijHQmx2f3J/ny4wf4Qt9eGuqiQZcnIiFlZu+5+0DJcwr0uzM5n+ZP3x/mjSMXOX91nkRjHf/wsym+/Ph+7tvVEnR5IhIyCvRt4O78v/PXeOPIRd49dYmVrPP4wR28+MQBfunh3dTH1GsXkbt3q0Av58YiKYOZ8eS9HTx5bwdX55b5zuAwb/zoE77yrQ/Y0RznSwMpvnx4Pwd2NgddqoiElHroWyiXc/7u3FXeOPIJf33mCtmc8/fv7+DFx/fz+Yd2UxfV2mgicns05FIBLk0v8SdHh3jz6EXGppfobK3n+ce6ef7wfrqSjUGXJyJVQoFeQTLZHD84O84bP7rI989ewYDPPbiLFx/fz+ce3EU0UurGXBGRPAV6hRqeXCj02ocYn11mX6KB5w/v5x8/1s3utoagyxORCqRAr3Ar2Rx/ffoyb/zoIn/30VWiEeOph3bx4uMH+Hv3dRBRr11ECnSVS4Wri0Z4pncvz/Tu5cLVeb519CLfGRzm3VOX2b+jiRcO7+dLAyk6WuqDLlVEKph66BVqOZPl3VOX+eYPP+HIxxPURY1fengPLz5+gCfu2YGZeu0itUhDLlXu3JVZ3jgyxFvvDTGzlOGezma+fHg/X/zZFMmmeNDlicg2UqCHxNJKlj8/PsY3j3zC+xeniMcifKF3L7/ymX387IF22hrqgi5RRLaYAj2EzozN8MaRi/zPD0aYW85gBj+zp43DPe0M9Ozg8MEdulJGJIQU6CG2mM7ywcVJjl6Y5OiFCd6/OMlC4dF43TsaeezADh47uIPHetq5
t7NFY+8iVU5XuYRYYzzKk/d18OR9HUD+xqXTYzP5gP94gv/z0Tjf/WAEgPamunzvvWcHAz3tPNKV0PIDIiGiHnrIuTsfX51nsNCDP3phggvXFgBoqIvwaHf7Wg/+0f3ttNTrO16kkmnIRW5wZXbphoA/PTpDziEaMQ7tbWOgp73Qi99BZ6uufRepJAp0uaW55QzvfzLJ4IUJfnRhgh8PTbG0kgPgYEczAwdWe/E76NnZpHF4kQAp0OW2pDM5To5O5wP+40kGP5lgamEFgI6Weh7raeexnnzAP7S3lZjG4UW2jQJd7kou5/x0fI6jF6734ocnFwFojkf57IF2Bg7s4P7dLXQlG9mXbKSjJa6evMgWUKDLphubXrwe8B9PcPbyLMX/lOKxCF3JxkLAN7CvsN2VbKSrvZE9iQY9lk/kDuiyRdl0exONPNvfyLP9+wCYXVphaGKRkalFRqfyf49MLTIyucgPzo5zZXb5hp83g86W+nzQtxeCP9FAV3sT+5INpJJNtDXG1MsXuQ0KdNkUrQ11HNpXx6F9bSXPL2eyXJpeYmRyNfSXGJlaYHRqidOjM/zV6cukM7kbfqY5HqWrvXGtd78v2UiqaH9Xa73G70WKKNBlW9THohzY2bzuQ7Ldnatz6bXe/ejUIsOT13v7x4ammCxMzK6KRow9bQ1rwzpd7Y3sTTSSbKqjtaGO1oYYrfWxte2meFQ9fgk1BbpUBDOjs7WeztZ6+ruTJdvML2cYm15kZCrf0y8e2jl6YZI/Oz5GNrf+nFA0YrTUx2htiNFSH6NtNfQb8qHfUrTdVny8/sZtPSZQKpUCXapGc32M+3a1ct+u1pLnszlnfHaZ6cUVZpdWmF3OMLuUyW8X/p5byh+bKeyPTS/x0ZXrbTK3+EJY1VIfKwr5678BFP9W0NZYR3tznJ3NcXY0x9nZEqe9Ka6lFmRLKdAlNKIRY0+igT2JO1tl0t1ZWsmt+2Uwu5Qp+lPYX15haiHN0MTC2pfE8k1zAcUSjXVrIb8a9Pnt+pLHdSWQ3A4FukiBmdEYj9IYj7LrLl4nnckxs7TCxHyaa3NpJubTTMwvc20+v31tPs3EXJpPri3w/sUpJhfS6w4VtdTHrof8auC3rG7X3/AbwM7mehrj+gKoZQp0kU0Wj0XoaKnPPwN298btczlnZmnleuDPrf59/UtgYj7N6PQSJ0enmZhPs5It/QXQWBe9oYffUh/Dyf/2kctBzp2cF/YL2zl33PNDVqvbuaLzXtRuvZ9da58r3T4ei5BsitPeVEd7U5zkTX+3N9cVzufbtDXU6eHod0CBLhKwSMRINsVJNsW5t3Pj9u7O7HKGibl0UeAvr/X8V38LuDaX5uK1BcwgYkbE7Pp2hMK+ESkcixbORyNGXcQ+dT5i3LRf9Hr26deLRK63X1rJMbWQZnJhhTOXZphayA9VrTdlEbH88NSN4V/4QmjOH9uxeqz5ertaH6JSoItUGTOjrSHfi+3pKH0ZaDXI5ZzZpQyTC2kmF9JMLawUtlcK4X99e2x6iTNjM0wurLC4kl33NZvi0dK/ATTlfwNojEfzX1hRIxqJEItY/k9hvy5iRAv7sUikaPvm/fzPRqNG3erxiAX+W4UCXUQCEYkYiaY6Ek119FD+F9PSSjYf9vMra73+/BdC8Xb+75GpRSYX0kwvrrAdq5xEjE99Eax9caztG7/51AP8SuEu682kQBeRqtJQF2VvIn8TWbmyOWdmcYWlTJZM1snmnEwuRybnN+4XtldyTrawn8nl/2RzOVbW2jrZbK7onLOSza2dyxTO5Y8XXquwn8k6yaateaC7Al1EQi8aMdqb40GXseV0l4OISEgo0EVEQkKBLiISEgp0EZGQKCvQzexpMztrZufM7JUS5+vN7E8K54+YWc9mFyoiIre2YaCbWRR4DXgGOAS8YGaHbmr268Cku98HfAP4nc0uVEREbq2cHvph4Jy7n3f3NPAm8NxNbZ4D/lth+y3g86YnCYiI
bKtyAr0LGCraHy4cK9nG3TPANLDz5hcys5fMbNDMBsfHx++sYhERKamcG4tK9bRvvom2nDa4++vA6wBmNm5mn5Tx/qV0AFfv8GfDSJ/HjfR5XKfP4kZh+DwOrHeinEAfBrqL9lPA6Dpths0sBiSAiVu9qLuXsa5caWY26O4Dd/rzYaPP40b6PK7TZ3GjsH8e5Qy5HAXuN7ODZhYHngfevqnN28CvFba/CPyt+3YshSMiIqs27KG7e8bMXgbeBaLAH7v7KTN7FRh097eBPwL+u5mdI98zf34rixYRkU8ra3Eud38HeOemY18r2l4CvrS5pd3S69v4XtVAn8eN9Hlcp8/iRqH+PEwjIyIi4aBb/0VEQkKBLiISElUX6ButK1MrzKzbzL5vZmfM7JSZfTXomiqBmUXN7AMz+17QtQTNzJJm9paZ/aTw7+Tngq4pKGb2Lwv/T06a2bfMrCHomrZCVQV6mevK1IoM8Nvu/hDwBPAbNfxZFPsqcCboIirEfwH+0t1/BuinRj8XM+sCvgIMuPsj5K/WC+WVeFUV6JS3rkxNcPcxd3+/sD1L/j/rzUsy1BQzSwG/DPxh0LUEzczagF8gf0kx7p5296lgqwpUDGgs3PjYxKdvjgyFagv0ctaVqTmF5YofBY4EW0ng/jPwr4Bc0IVUgHuAceC/Foag/tDMmoMuKgjuPgL8J+AiMAZMu/v/DraqrVFtgV7WmjG1xMxagD8FftPdZ4KuJyhm9gXgiru/F3QtFSIGfBb4fXd/FJgHanLOyczayf8mfxDYBzSb2a8GW9XWqLZAL2ddmZphZnXkw/yb7v7doOsJ2M8Dz5rZBfJDcb9oZv8j2JICNQwMu/vqb21vkQ/4WvQU8LG7j7v7CvBd4MmAa9oS1Rbo5awrUxMK683/EXDG3X8v6HqC5u7/xt1T7t5D/t/F37p7KHth5XD3S8CQmT1YOPR54HSAJQXpIvCEmTUV/t98npBOEJd163+lWG9dmYDLCsrPA/8EOGFmPy4c+7eFZRpEAP4F8M1C5+c88M8CricQ7n7EzN4C3id/ddgHhHQJAN36LyISEtU25CIiIutQoIuIhIQCXUQkJBToIiIhoUAXEQkJBbqISEgo0EVEQuL/A3B0zmxrmcdjAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Evaluate `net` on the test loader and plot the training-loss history.\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "for features, labels in testset:\n",
    "    o = net(features.reshape(-1, 784))\n",
    "    # Predicted class = index of the largest output in each row.\n",
    "    predict = np.argmax(np.asarray(o.data), axis=1)\n",
    "    targets = np.asarray(labels.data).reshape(-1)\n",
    "    correct += int(np.sum(predict == targets))\n",
    "    # Count the samples actually present: a final partial batch (10000 is not a\n",
    "    # multiple of 32) must not be counted as a full batch of 32.\n",
    "    total += len(targets)\n",
    "print(\"Accuracy %.5f\"%(correct / total))\n",
    "\n",
    "plt.plot(hist)\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Average training loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
